/**********************************************************************
* Copyright (c) 2013-2014  Red Hat, Inc.
*
* Developed by Daynix Computing LTD.
*
* Authors:
*     Dmitry Fleytman <dmitry@daynix.com>
*     Pavel Gurvich <pavel@daynix.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**********************************************************************/

#include "stdafx.h"
#include "DeviceAccess.h"
#include "trace.h"
#include "Irp.h"
#include "RegText.h"
#include "DeviceAccess.tmh"

#if TARGET_OS_WIN_XP
#include <initguid.h>
#include <Usbbusif.h>
#endif

#if !TARGET_OS_WIN_XP
/*
 * Builds the USBD client context for m_TargetDevice in three steps:
 *   1. create a private device object (auto-generated name) owned by m_Driver,
 *   2. attach it on top of m_TargetDevice's stack,
 *   3. open a USBD handle (contract version 602) through that attachment.
 * Returns true only when all three steps succeed. On failure, whatever was
 * already acquired stays recorded in the members and is released by
 * ~CWdmUSBD(), which checks each member for nullptr.
 */
bool CWdmUSBD::Create()
{
    auto status = IoCreateDevice(m_Driver, 0, nullptr, FILE_DEVICE_UNKNOWN, FILE_AUTOGENERATED_DEVICE_NAME, FALSE, &m_USBDDevice);
    if (!NT_SUCCESS(status))
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! IoCreateDevice failed, %!STATUS!", status);
        return false;
    }

    // The returned pointer is the device we actually attached to; it is what
    // IoDetachDevice needs later.
    m_AttachmentPoint = IoAttachDeviceToDeviceStack(m_USBDDevice, m_TargetDevice);
    if (!m_AttachmentPoint)
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! IoAttachDeviceToDeviceStack returned NULL");
        return false;
    }

    status = USBD_CreateHandle(m_USBDDevice, m_AttachmentPoint, USBD_CLIENT_CONTRACT_VERSION_602, 'DBHR', &m_USBDHandle);
    const bool handleCreated = NT_SUCCESS(status);
    if (!handleCreated)
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! USBD_CreateHandle failed, %!STATUS!", status);
    }

    return handleCreated;
}

/*
 * Releases whatever Create() managed to acquire, in reverse order of
 * acquisition: the USBD handle first, then the stack attachment, then the
 * private device object. Each member is only touched if it is non-null, so
 * a partially-successful Create() is cleaned up correctly.
 */
CWdmUSBD::~CWdmUSBD()
{
    if (m_USBDHandle)
    {
        USBD_CloseHandle(m_USBDHandle);
    }

    // m_AttachmentPoint is the lower device returned by IoAttachDeviceToDeviceStack.
    if (m_AttachmentPoint)
    {
        IoDetachDevice(m_AttachmentPoint);
    }

    if (m_USBDDevice)
    {
        IoDeleteDevice(m_USBDDevice);
    }
}
#endif

/*
 * Returns the Address field reported by an IRP_MN_QUERY_CAPABILITIES query
 * on the device, or NO_ADDRESS when the query fails.
 */
ULONG CWdmDeviceAccess::GetAddress()
{
    DEVICE_CAPABILITIES caps;
    const auto status = QueryCapabilities(caps);

    if (NT_SUCCESS(status))
    {
        return caps.Address;
    }

    return NO_ADDRESS;
}

/*
 * Sends IRP_MN_QUERY_ID of the requested type to m_DevObj and returns the ID
 * string copied into a non-paged buffer (see MakeNonPagedDuplicate, which
 * also frees the buffer returned by the IRP).
 *
 * Returns nullptr if the IRP cannot be created, if the query fails, if the
 * driver returned no data, or if the non-paged copy cannot be allocated.
 */
PWCHAR CWdmDeviceAccess::QueryBusID(BUS_QUERY_ID_TYPE idType)
{
    CIrp irp;

    auto status = irp.Create(m_DevObj);

    if (!NT_SUCCESS(status))
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during IRP creation", status);
        return nullptr;
    }

    irp.Configure([idType] (PIO_STACK_LOCATION s)
                  {
                      s->MajorFunction = IRP_MJ_PNP;
                      s->MinorFunction = IRP_MN_QUERY_ID;
                      s->Parameters.QueryId.IdType = idType;
                  });

    status = irp.SendSynchronously();

    if (!NT_SUCCESS(status))
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during %!devid! query", status, idType);
        return nullptr;
    }

    // Initialized explicitly: the only write happens inside the ReadResult
    // callback. If that callback is not invoked, this must read as "no ID"
    // rather than as an indeterminate pointer that would later be freed.
    PWCHAR idData = nullptr;
    irp.ReadResult([&idData](ULONG_PTR information)
                   { idData = reinterpret_cast<PWCHAR>(information); });

    return (idData != nullptr) ? MakeNonPagedDuplicate(idType, idData) : nullptr;
}

/*
 * Issues IRP_MN_QUERY_CAPABILITIES to m_DevObj and fills Capabilities.
 *
 * Capabilities is left untouched if the IRP cannot be created. Otherwise it is
 * zeroed and pre-initialized (Size, Version = 1, Address and UINumber set to
 * 0xFFFFFFFF, the "unknown" sentinel documented for this IRP) before the query
 * is sent.
 * Returns the IRP creation status or the status of the query itself.
 */
NTSTATUS CWdmDeviceAccess::QueryCapabilities(DEVICE_CAPABILITIES &Capabilities)
{
    CIrp capsIrp;

    auto status = capsIrp.Create(m_DevObj);
    if (NT_SUCCESS(status))
    {
        RtlZeroMemory(&Capabilities, sizeof(Capabilities));
        Capabilities.Size = sizeof(Capabilities);
        Capabilities.Version = 1;
        Capabilities.Address = 0xFFFFFFFF;
        Capabilities.UINumber = 0xFFFFFFFF;

        capsIrp.Configure([&Capabilities](PIO_STACK_LOCATION stack)
                          {
                              stack->MajorFunction = IRP_MJ_PNP;
                              stack->MinorFunction = IRP_MN_QUERY_CAPABILITIES;
                              stack->Parameters.DeviceCapabilities.Capabilities = &Capabilities;
                          });

        status = capsIrp.SendSynchronously();
        if (!NT_SUCCESS(status))
        {
            TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during capabilities query", status);
        }
    }
    else
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during IRP creation", status);
    }

    return status;
}

/*
 * Computes the size in bytes of the ID buffer returned for idType.
 * Hardware and compatible IDs are measured as multi-strings (CRegMultiSz);
 * every other ID type is measured as a single string (CRegSz).
 */
SIZE_T CWdmDeviceAccess::GetIdBufferLength(BUS_QUERY_ID_TYPE idType, PWCHAR idData)
{
    const bool isMultiString = (idType == BusQueryHardwareIDs) || (idType == BusQueryCompatibleIDs);

    if (isMultiString)
    {
        // NOTE(review): the extra WCHAR presumably covers the list's final
        // terminating NUL that CRegMultiSz::GetBufferLength does not count --
        // confirm against RegText.h.
        return CRegMultiSz::GetBufferLength(idData) + sizeof(WCHAR);
    }

    return CRegSz::GetBufferLength(idData);
}

/*
 * Reads the DEVPKEY_Device_PowerData property of m_DevObj into powerData.
 * powerData.PD_Size is always set first. On Windows XP builds the property is
 * not queried and the function always reports failure.
 * Returns true if IoGetDevicePropertyData succeeded.
 */
bool CWdmDeviceAccess::QueryPowerData(CM_POWER_DATA& powerData)
{
    powerData.PD_Size = sizeof(powerData);
#if TARGET_OS_WIN_XP
    return false;
#else
    // Required by the API signature; the values are not inspected here.
    ULONG requiredSize;
    DEVPROPTYPE propertyType;

    auto status = IoGetDevicePropertyData(m_DevObj, &DEVPKEY_Device_PowerData, LOCALE_NEUTRAL, 0,
                                          sizeof(powerData), &powerData, &requiredSize, &propertyType);
    const bool succeeded = NT_SUCCESS(status);
    if (!succeeded)
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS!", status);
    }
    return succeeded;
#endif
}

/*
 * Completion callback for the power IRP requested via PoRequestPowerIrp.
 * Context must be the CWdmEvent the requester is waiting on; the event is
 * signalled once the request completes. The completion status is not
 * examined.
 */
static void PowerRequestCompletion(
    _In_ PDEVICE_OBJECT DeviceObject,
    _In_ UCHAR MinorFunction,
    _In_ POWER_STATE PowerState,
    _In_opt_ PVOID Context,
    _In_ PIO_STATUS_BLOCK IoStatus
)
{
    UNREFERENCED_PARAMETER(DeviceObject);
    UNREFERENCED_PARAMETER(MinorFunction);
    UNREFERENCED_PARAMETER(IoStatus);
    // PowerState is only used inside the trace macro; presumably kept
    // unreferenced-safe for builds where tracing compiles away.
    UNREFERENCED_PARAMETER(PowerState);

    // "- 1" maps the DEVICE_POWER_STATE enum value to its D-number (PowerDeviceD0 == 1).
    TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_DEVACCESS, "%!FUNC! -> D%d", PowerState.DeviceState - 1);

    auto const completionEvent = static_cast<CWdmEvent *>(Context);
    completionEvent->Set();
}

/*
 * Copies the ID buffer idData (sized per GetIdBufferLength for idType) into a
 * new USBDK_NON_PAGED_POOL allocation tagged 'IDHR'.
 * idData is freed with ExFreePool in every case, including allocation failure.
 * Returns the new buffer, or nullptr if the allocation failed.
 */
PWCHAR CWdmDeviceAccess::MakeNonPagedDuplicate(BUS_QUERY_ID_TYPE idType, PWCHAR idData)
{
    const auto byteCount = GetIdBufferLength(idType, idData);

    auto copy = static_cast<PWCHAR>(ExAllocatePoolWithTag(USBDK_NON_PAGED_POOL, byteCount, 'IDHR'));
    if (copy == nullptr)
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Failed to allocate non-paged buffer for %!devid!", idType);
    }
    else
    {
        RtlCopyMemory(copy, idData, byteCount);
    }

    ExFreePool(idData);
    return copy;
}

/*
 * Sends IRP_MN_QUERY_INTERFACE for the interface identified by guid/intfVer to
 * m_DevObj. The caller's buffer intf (intfSize bytes, at least
 * sizeof(INTERFACE)) is zeroed before the IRP is sent and is filled by the
 * driver that answers the query. intfCtx is passed as InterfaceSpecificData.
 *
 * Returns the IRP creation status or the status of the query.
 */
NTSTATUS CWdmDeviceAccess::QueryForInterface(const GUID &guid, __out INTERFACE &intf,
    USHORT intfSize, USHORT intfVer, __in_opt PVOID intfCtx)
{
    ASSERT(intfSize >= sizeof(INTERFACE));
    CIrp Irp;
    NTSTATUS status = Irp.Create(m_DevObj);
    if (!NT_SUCCESS(status))
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! IRP alloc failed, %!STATUS!", status);
        return status;
    }

    // Capturing by reference is safe: the IRP is sent synchronously below,
    // before any captured object goes out of scope.
    Irp.Configure([&](PIO_STACK_LOCATION s)
    {
        s->MajorFunction = IRP_MJ_PNP;
        s->MinorFunction = IRP_MN_QUERY_INTERFACE;
        auto q = &s->Parameters.QueryInterface;
        q->InterfaceType = &guid;
        q->Version = intfVer;
        q->Interface = &intf;
        q->Size = intfSize;
        q->InterfaceSpecificData = intfCtx;
        memset(q->Interface, 0, q->Size);
    });

    status = Irp.SendSynchronously();
    if (!NT_SUCCESS(status))
    {
        // Warning, not error: an unsupported interface can be a normal outcome
        // that callers handle themselves.
        TraceEvents(TRACE_LEVEL_WARNING, TRACE_DEVACCESS, "%!FUNC! Interface query failed, %!STATUS!", status);
    }

    return status;
}

/*
 * Cycles the device's port via IOCTL_INTERNAL_USB_CYCLE_PORT.
 *
 * If ForceD0 is set and the device's most recent power state (as read by
 * QueryPowerData) is not D0, a set-power D0 request is issued first and
 * awaited. Bringing the device to D0 is best effort: if the power request
 * cannot be issued, this is traced and the port cycle is still attempted.
 *
 * Returns the status of creating or sending the cycle-port IRP.
 */
NTSTATUS CWdmUsbDeviceAccess::Reset(bool ForceD0)
{
    CIoControlIrp Irp;
    CM_POWER_DATA powerData;
    if (ForceD0 && QueryPowerData(powerData) && powerData.PD_MostRecentPowerState != PowerDeviceD0)
    {
        TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_DEVACCESS, "%!FUNC! device power state D%d", powerData.PD_MostRecentPowerState - 1);
        POWER_STATE PowerState;
        CWdmEvent Event;
        PowerState.DeviceState = PowerDeviceD0;
        auto powerStatus = PoRequestPowerIrp(m_DevObj, IRP_MN_SET_POWER, PowerState, PowerRequestCompletion, &Event, nullptr);
        if (NT_SUCCESS(powerStatus))
        {
            // PowerRequestCompletion signals Event when the power IRP completes.
            Event.Wait();
        }
        else
        {
            TraceEvents(TRACE_LEVEL_WARNING, TRACE_DEVACCESS, "%!FUNC! PoRequestPowerIrp failed, %!STATUS!, resetting anyway", powerStatus);
        }
    }

    auto status = Irp.Create(m_DevObj, IOCTL_INTERNAL_USB_CYCLE_PORT);

    if (!NT_SUCCESS(status))
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during IOCTL IRP creation", status);
        return status;
    }

    status = Irp.SendSynchronously();

    if (!NT_SUCCESS(status))
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Send IOCTL IRP Error %!STATUS!", status);
    }

    return status;
}

/*
 * Reads the device descriptor (descriptor index 0) into Descriptor by
 * submitting a descriptor-request URB to m_DevObj synchronously.
 * Returns the URB submission status.
 */
NTSTATUS CWdmUsbDeviceAccess::GetDeviceDescriptor(USB_DEVICE_DESCRIPTOR &Descriptor)
{
    URB descriptorUrb;
    UsbDkBuildDescriptorRequest(descriptorUrb, USB_DEVICE_DESCRIPTOR_TYPE, 0, Descriptor);

    const auto status = UsbDkSendUrbSynchronously(m_DevObj, descriptorUrb);
    return status;
}

/*
 * Reads configuration descriptor number Index into the caller's buffer
 * (Descriptor, Length bytes) via a synchronous descriptor-request URB.
 *
 * Returns:
 *  - the URB failure status, if the request failed and nothing was received;
 *  - USBD_STATUS_INAVLID_CONFIGURATION_DESCRIPTOR (spelling is the WDK macro's
 *    own), if the request reported success but wTotalLength is 0;
 *  - the URB failure status, if the request failed although the whole
 *    descriptor fits in the buffer;
 *  - STATUS_SUCCESS otherwise. This includes a failed request whose
 *    wTotalLength exceeds Length, so that the caller can read wTotalLength
 *    from the header and retry with a larger buffer.
 */
NTSTATUS CWdmUsbDeviceAccess::GetConfigurationDescriptor(UCHAR Index, USB_CONFIGURATION_DESCRIPTOR &Descriptor, size_t Length)
{
    // Zeroing makes "nothing was delivered" detectable as wTotalLength == 0.
    RtlZeroMemory(&Descriptor, Length);

    URB Urb;
    UsbDkBuildDescriptorRequest(Urb, USB_CONFIGURATION_DESCRIPTOR_TYPE, Index, Descriptor, static_cast<ULONG>(Length));

    auto status = UsbDkSendUrbSynchronously(m_DevObj, Urb);
    if (Descriptor.wTotalLength == 0)
    {
        if (!NT_SUCCESS(status))
        {
            // The request itself failed: report the real failure instead of
            // masking it as a malformed descriptor.
            TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Failed to retrieve the configuration descriptor, %!STATUS!", status);
            return status;
        }

        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Invalid configuration descriptor on unknown size received.");
        return USBD_STATUS_INAVLID_CONFIGURATION_DESCRIPTOR;
    }

    // A failure is only fatal when the descriptor fits in the buffer. When it
    // does not, the error presumably reflects the short buffer while the
    // header (with wTotalLength) still arrived -- NOTE(review): confirm
    // against callers that re-query with wTotalLength.
    if ((Descriptor.wTotalLength <= Length) && !NT_SUCCESS(status))
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Failed to retrieve the configuration descriptor, %!STATUS!", status);
        return status;
    }

    return STATUS_SUCCESS;
}

/*
 * Determines the bus speed of the USB device DevObj.
 *
 * Non-XP builds: a temporary CWdmUSBD context is created for DevObj (using
 * DriverObj) and asked for SuperSpeed, then HighSpeed; FullSpeed is the
 * fallback. NoSpeed is returned if the context cannot be created.
 *
 * XP builds: DriverObj is unused. USB_BUS_INTERFACE_USBDI_V1 is queried and
 * its IsDeviceHighSpeed() decides between HighSpeed and FullSpeed; NoSpeed is
 * returned if the interface query fails.
 */
USB_DK_DEVICE_SPEED UsbDkWdmUsbDeviceGetSpeed(PDEVICE_OBJECT DevObj, PDRIVER_OBJECT DriverObj)
{
#if !TARGET_OS_WIN_XP
    CWdmUSBD usbd(DriverObj, DevObj);
    if (!usbd.Create())
    {
        return NoSpeed;
    }

    USB_DK_DEVICE_SPEED speed = FullSpeed;
    if (usbd.IsSuperSpeed())
    {
        speed = SuperSpeed;
    }
    else if (usbd.IsHighSpeed())
    {
        speed = HighSpeed;
    }
    return speed;
#else //TARGET_OS_WIN_XP
    UNREFERENCED_PARAMETER(DriverObj);

    // The interface lives on the stack; it is dereferenced before returning.
    USB_BUS_INTERFACE_USBDI_V1 busInterface;
    CWdmDeviceAccess access(DevObj);
    const NTSTATUS status = access.QueryForInterface(
        USB_BUS_INTERFACE_USBDI_GUID,
        reinterpret_cast<INTERFACE &>(busInterface),
        sizeof(USB_BUS_INTERFACE_USBDI_V1),
        USB_BUSIF_USBDI_VERSION_1
        );

    if (!NT_SUCCESS(status))
    {
        return NoSpeed;
    }

    ASSERT(busInterface.IsDeviceHighSpeed && busInterface.InterfaceDereference);
    const auto speed = busInterface.IsDeviceHighSpeed(busInterface.BusContext) ? HighSpeed : FullSpeed;
    busInterface.InterfaceDereference(busInterface.BusContext);
    return speed;
#endif //TARGET_OS_WIN_XP
}

bool UsbDkGetWdmDeviceIdentity(const PDEVICE_OBJECT PDO,
                               CObjHolder<CRegText> *DeviceID,
                               CObjHolder<CRegText> *InstanceID)
{
    CWdmDeviceAccess pdoAccess(PDO);

    if (DeviceID != nullptr)
    {
        *DeviceID = pdoAccess.GetDeviceID();
        if (!(*DeviceID) || (*DeviceID)->empty())
        {
            TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! No Device IDs read");
            return false;
        }
    }

    if (InstanceID != nullptr)
    {
        *InstanceID = pdoAccess.GetInstanceID();
        if (!(*InstanceID) || (*InstanceID)->empty())
        {
            TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! No Instance ID read");
            return false;
        }
    }

    return true;
}

/*
 * Submits Urb to Target as an IOCTL_INTERNAL_USB_SUBMIT_URB IRP (URB pointer
 * passed in Parameters.Others.Argument1) and waits for completion.
 * Returns the IRP creation failure status, or the status of the sent IRP.
 */
NTSTATUS UsbDkSendUrbSynchronously(PDEVICE_OBJECT Target, URB &Urb)
{
    CIoControlIrp urbIrp;

    auto status = urbIrp.Create(Target, IOCTL_INTERNAL_USB_SUBMIT_URB);
    if (NT_SUCCESS(status))
    {
        urbIrp.Configure([&Urb] (PIO_STACK_LOCATION stack)
                         { stack->Parameters.Others.Argument1 = &Urb; });

        status = urbIrp.SendSynchronously();
        if (!NT_SUCCESS(status))
        {
            TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Send URB IRP Error %!STATUS!", status);
        }
    }
    else
    {
        TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during IOCTL IRP creation", status);
    }

    return status;
}
